from __future__ import annotations
import os
import random
import pandas as pd
from tqdm import tqdm

from utils.util import read_json, write_json
from planning.src.planner import Planner
from planning.src.protocol import Protocol
from planning.src.metrics import Metrics

class Utility:
    """Run planning and evaluation for a single ``domain``/``task`` pair.

    ``plan`` asks the planner for a protocol per dataset entry and model
    variant and stores each result as JSON under ``result_path``.
    ``evaluate`` scores the stored protocols against the dataset and upserts
    one row per (Domain, Task, ID, Model) into a CSV under ``data_path``.
    """

    # Columns that together identify one row of the results CSV.
    _KEY_COLUMNS = ["Domain", "Task", "ID", "Model"]

    def __init__(self, domain, task) -> None:
        """Set up paths, load the dataset and load (or create) the results table.

        Args:
            domain: Domain name; used in dataset/result paths and the CSV name.
            task: Task name; used in dataset/result paths and the CSV name.
        """
        self.domain = domain
        self.task = task
        self.planner = Planner(domain=domain, task=task)
        self.data_path = "planning/data/"
        self.dataset_path = f"dataset/planning_picked/{domain}/{task}/"
        self.result_path = f"planning_result/{domain}/{task}/"
        self.dataset = self.load_dataset()
        self.id_list = [protocol.id for protocol in self.dataset]
        # Each name is "<mode>_<method>"; plan() splits on the underscore.
        self.models = [
            "flatten_baseline",
            "atomic_baseline",
            "atomic_internal",
            "dsl_internal",
            "dsl_external",
            "multi-dsl_internal",
            "multi-dsl_external",
        ]
        self.evaluate_result_path = self.data_path + f"{self.domain}_{self.task}.csv"
        if os.path.exists(self.evaluate_result_path):
            # Read IDs as strings so they compare equal to str(protocol.id).
            self.result_csv = pd.read_csv(self.evaluate_result_path, dtype={'ID': str})
        else:
            self.result_csv = pd.DataFrame(
                columns=self._KEY_COLUMNS + [f"Dim-{i}" for i in range(1, 7)]
            )

    def load_dataset(self):
        """Load every ``*.json`` protocol in ``dataset_path`` and return them sorted.

        Entries that are not ``.json`` files (hidden files, sub-directories)
        are skipped, as are files for which ``Protocol.fromjson`` returns a
        falsy value.

        Returns:
            Sorted list of protocols.
            NOTE(review): sorting relies on ``Protocol`` defining an ordering.
        """
        protocols = []
        # Sorted listing keeps the result independent of filesystem order.
        for filename in sorted(os.listdir(self.dataset_path)):
            if not filename.endswith(".json"):
                continue
            protocol = Protocol.fromjson(read_json(os.path.join(self.dataset_path, filename)))
            if protocol:
                protocols.append(protocol)
        return sorted(protocols)

    def plan(self, models: list[str] | None = None):
        """Generate a protocol for every dataset entry with each model variant.

        Results are written to ``<result_path>/<mode>_<method>/<id>.json``.
        Existing result files are overwritten: planning is re-run for every
        protocol on each call.

        Args:
            models: Model names of the form ``"<mode>_<method>"``. ``None`` or
                an empty list selects all of ``self.models``:
                ["flatten_baseline", "atomic_baseline", "atomic_internal",
                "dsl_internal", "dsl_external", "multi-dsl_internal",
                "multi-dsl_external"].

        Raises:
            ValueError: If a model name does not contain exactly one ``_``.
        """
        models = models if models else self.models

        # Split names and create the output directories once, not per protocol.
        targets = []
        for model in models:
            mode, method = model.split("_")
            dir_path = os.path.join(self.result_path, f"{mode}_{method}")
            os.makedirs(dir_path, exist_ok=True)
            targets.append((mode, method, dir_path))

        for protocol in tqdm(self.dataset, desc=f"{self.domain}-{self.task}"):
            for mode, method, dir_path in targets:
                novel_protocol: Protocol = self.planner.plan(protocol, mode=mode, method=method)
                # A falsy planner result is skipped; nothing is written for it.
                if novel_protocol:
                    metadata = novel_protocol.tojson("id", "title", "description", "pseudocode", "program")
                    write_json(os.path.join(dir_path, f"{novel_protocol.id}.json"), metadata)

    @staticmethod
    def _program_type(model: str) -> str:
        """Map a model name to the program type passed to ``Metrics``."""
        if model.startswith("multi-dsl"):
            return "multi-dsl"
        if model.startswith("dsl"):
            return "dsl"
        return "pseudocode"

    def evaluate(self, models: list[str] | None = None):
        """Score stored planning results and persist them to the results CSV.

        For each model, every ``*.json`` file in ``<result_path>/<model>`` is
        evaluated with ``single_evaluate``. Models without a result directory
        contribute no rows. The CSV is rewritten after each model.

        Args:
            models: Model names to evaluate. ``None`` or an empty list selects
                all of ``self.models``.
        """
        models = models if models else self.models
        for model in models:
            program_type = self._program_type(model)
            model_res_path = os.path.join(self.result_path, model)
            res_lines = []
            if os.path.isdir(model_res_path):
                # Only result files written by plan(); sorted for stable row order.
                filenames = sorted(
                    name for name in os.listdir(model_res_path) if name.endswith(".json")
                )
                for filename in tqdm(filenames, desc=model):
                    protocol = Protocol.fromjson(read_json(os.path.join(model_res_path, filename)))
                    results = self.single_evaluate(novel_protocol=protocol, novel_protocol_type=program_type)
                    # NOTE(review): assumes results holds the six Dim-1..Dim-6
                    # values in column order -- confirm against Metrics.
                    res_lines.append([self.domain, self.task, str(protocol.id), model] + list(results.values()))
            self.__dump_result(res_lines)

    def single_evaluate(self, novel_protocol: Protocol, novel_protocol_type: str):
        """Compute metrics for one generated protocol against its ground truth.

        The ground truth is read from ``<dataset_path>/<novel_protocol.id>.json``.

        Args:
            novel_protocol: The generated protocol to score.
            novel_protocol_type: Program type handed to ``Metrics`` as
                ``novel_program_type``.

        Returns:
            Whatever ``Metrics.get_metrics()`` returns; ``evaluate`` reads it
            through ``.values()``.
        """
        groundtruth_path = os.path.join(self.dataset_path, f"{novel_protocol.id}.json")
        groundtruth_protocol = Protocol.fromjson(read_json(groundtruth_path))
        metrics = Metrics(
            domain=self.domain,
            novel_protocol=novel_protocol,
            groundtruth_protocol=groundtruth_protocol,
            novel_program_type=novel_protocol_type,
        )
        return metrics.get_metrics()

    def __dump_result(self, results_list):
        """Upsert rows into ``result_csv`` and write the table to disk.

        Args:
            results_list: Rows as lists in CSV column order. The first four
                values (Domain, Task, ID, Model) form the key: a row with a
                matching key is overwritten, otherwise the row is appended.
        """
        for data in results_list:
            mask = self.result_csv[self._KEY_COLUMNS[0]] == data[0]
            for position, column in enumerate(self._KEY_COLUMNS[1:], start=1):
                mask &= self.result_csv[column] == data[position]
            if self.result_csv[mask].empty:
                self.result_csv.loc[len(self.result_csv)] = data
            else:
                self.result_csv.loc[mask, :] = data

        # Create the output directory if needed so results are not lost to an
        # OSError from to_csv after the evaluation has already run.
        out_dir = os.path.dirname(self.evaluate_result_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        self.result_csv.to_csv(self.evaluate_result_path, index=False)